# The code for the applicability domain (AD) calculation was taken from: https://github.com/marcossantanaioc/qsar_ad/tree/master/

__all__ = ['BaseDomain', 'kNNDomain', 'SVMDomain']

import pandas as pd
import numpy as np
from rdkit import Chem
from scipy.spatial import distance

class BaseDomain:
    """Common parent class for the applicability-domain estimators in this file."""

    def calculate_applicability_domain(self):
        """Placeholder hook: performs no work and returns ``None``.

        ``kNNDomain`` overrides this method with its own signature.
        """
        return None

class kNNDomain(BaseDomain):
    """Applicability domain (AD) based on distances to the nearest reference samples.

    A distance threshold is derived from the pairwise distances inside the
    reference set ``Xref`` (``mean + z * std``).  A query sample is reported as
    inside the domain when the mean distance to its ``k`` closest reference
    samples does not exceed that threshold.
    """

    def __init__(self, Xref: np.ndarray, metric='euclidean'):
        """Store the reference set and compute the AD threshold from it.

        Args:
            Xref: Reference samples, one row per sample (passed to
                ``scipy.spatial.distance.cdist``, which expects 2-D input).
            metric: Distance metric forwarded to ``cdist``.
        """
        self.Xref = Xref
        self.metric = metric
        self.ad_threshold = self.calculate_ad_threhold(metric=metric)

    @property
    def ad_threshold(self):
        """Distance threshold below (or at) which a sample counts as in-domain."""
        return self._ad_threshold

    @ad_threshold.setter
    def ad_threshold(self, v):
        self._ad_threshold = v

    def calculate_similarity_from_array(self, fp1: np.ndarray, fp2: np.ndarray = None, metric: str = None, z=0.5):
        """Return the pairwise distance matrix between ``fp1`` and ``fp2``.

        Despite the method name, the values are distances as computed by
        ``cdist`` (smaller means closer).

        Args:
            fp1: First set of samples, one row per sample.
            fp2: Second set of samples; when ``None`` the distances of ``fp1``
                against itself are computed.
            metric: Metric forwarded to ``cdist``; ``None`` falls back to
                ``self.metric``.
            z: Unused; kept so existing callers keep working.

        Returns:
            The distance matrix with length-one axes removed by ``squeeze()``
            (so a single-row input no longer yields a 2-D result).
        """
        # cdist rejects metric=None, so resolve the instance default first.
        metric = self.metric if metric is None else metric
        other = fp1 if fp2 is None else fp2
        return distance.cdist(fp1, other, metric=metric).squeeze()

    def calculate_ad_threhold(self, X: np.ndarray = None, metric: str = None, z=0.5):
        """Compute the AD distance threshold ``mean + z * std``.

        The statistics are taken over the full pairwise distance matrix of
        ``X`` against itself, i.e. including the self-distances on the diagonal
        and both (i, j) and (j, i) entries.

        Args:
            X: Samples to derive the threshold from; defaults to ``self.Xref``.
            metric: Distance metric; ``None`` falls back to ``self.metric``.
            z: Multiplier applied to the standard deviation.

        Returns:
            The threshold as a NumPy floating-point scalar.
        """
        X = self.Xref if X is None else X
        metric = self.metric if metric is None else metric
        pairwise = self.calculate_similarity_from_array(X, metric=metric)

        std_distances = np.std(pairwise)
        avg_distances = np.mean(pairwise)
        return (z * std_distances) + avg_distances

    def get_knn(self, fp: np.ndarray, ref_fp: np.ndarray = None, k: int = 5):
        """Return query-to-reference distances and the neighbour ordering.

        Args:
            fp: Query samples, one row per sample.
            ref_fp: Reference samples; defaults to ``self.Xref``.
            k: Not used here; every reference sample is ranked.  Kept so
                existing callers keep working.

        Returns:
            A tuple ``(distances, neighbours)``, both of shape
            ``(len(fp), len(ref_fp))``.  ``neighbours[i]`` holds the reference
            indices sorted by increasing distance to query ``i``.
        """
        ref_fp = self.Xref if ref_fp is None else ref_fp
        # The helper squeezes its result; restore the 2-D layout so single-row
        # queries or references are handled like any other input.
        distances = self.calculate_similarity_from_array(
            fp, ref_fp, metric=self.metric
        ).reshape(len(fp), len(ref_fp))
        neighbours = np.argsort(distances, axis=-1)
        return distances, neighbours

    def calculate_applicability_domain(self, fp: np.ndarray, ref_fp: np.ndarray = None, k: int = 5):
        """Decide for each query sample whether it lies inside the domain.

        Args:
            fp: Query samples, one row per sample.
            ref_fp: Reference samples; defaults to ``self.Xref``.
            k: Number of nearest reference samples to average over (>= 1).
                If ``k`` exceeds the number of reference samples, all of them
                are used.

        Returns:
            A tuple ``(avg_distance, in_domain)``: the mean distance of each
            query to its ``k`` nearest reference samples, and a boolean array
            that is ``True`` where that mean is ``<= self.ad_threshold``.

        Raises:
            AssertionError: If ``k`` is smaller than 1.
        """
        assert k >= 1, "k must be at least 1"
        ref_fp = self.Xref if ref_fp is None else ref_fp
        distances, neighbours = self.get_knn(fp, ref_fp, k=k)
        # Pick the k smallest distances per row through the sorted index matrix.
        nearest = np.take_along_axis(distances, neighbours[:, :k], 1)
        avg_distance = nearest.mean(-1)
        return (avg_distance, avg_distance <= self.ad_threshold)


class SVMDomain:
    """Applicability domain backed by a one-class SVM fitted on reference data."""

    def __init__(self, Xref: np.ndarray, svm_model=None):
        """Store the reference set and the model used to describe the domain.

        Args:
            Xref: Reference samples used as the default training data.
            svm_model: Estimator exposing ``set_params(**kwargs)`` and
                ``fit(X)``.  When ``None`` a default
                ``sklearn.svm.OneClassSVM`` is created.
        """
        self.Xref = Xref
        if svm_model is None:
            # A name imported in the class body is a class attribute and is not
            # visible inside methods, so import where the name is actually
            # used.  Doing it lazily also means scikit-learn is only required
            # when no model is supplied.
            from sklearn.svm import OneClassSVM
            svm_model = OneClassSVM()
        self.svm_model = svm_model

    @property
    def svm_model(self):
        """The estimator used by ``train_model``."""
        return self._svm_model

    @svm_model.setter
    def svm_model(self, v):
        self._svm_model = v

    def train_model(self, X: np.ndarray = None, params=None):
        """Configure and fit the underlying model.

        Args:
            X: Training samples; defaults to ``self.Xref``.
            params: Optional mapping of keyword arguments passed to
                ``svm_model.set_params``.  ``None`` means no overrides.

        Returns:
            The fitted ``svm_model`` instance.
        """
        X = self.Xref if X is None else X
        # `params` defaults to None rather than a shared mutable dict.
        overrides = {} if params is None else params
        self.svm_model.set_params(**overrides)
        self.svm_model.fit(X)
        return self.svm_model